// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <exception>
#include <fstream>
#include <iostream>
#include <mutex>
#include <stdexcept>
#include <string>
#include <utility>

#include "cutlass/bfloat16.h"
#include "cutlass/cutlass.h"
#include "cutlass/half.h"
#include "helper.h"
#include "paddle/extension.h"

/**
 * Helper macro for checking CUTLASS errors.
 *
 * Evaluates the cutlass::Status expression `status` exactly once and fails a
 * PD_CHECK, with the text from cutlassGetStatusString(), when the result is
 * not cutlass::Status::kSuccess.
 *
 * The macro-local uses a deliberately unusual name. A plain name such as
 * `error` would shadow a caller variable of the same name inside `status`,
 * turning `CUTLASS_CHECK(error)` into a self-initialization
 * (`cutlass::Status error = error;`), which is undefined behavior.
 */
#define CUTLASS_CHECK(status)                                                  \
  {                                                                            \
    const cutlass::Status cutlass_check_status_ = (status);                    \
    PD_CHECK(cutlass_check_status_ == cutlass::Status::kSuccess,               \
             cutlassGetStatusString(cutlass_check_status_));                   \
  }

/**
 * Kernel wrapper whose call operator only forwards to the wrapped kernel when
 * device code is compiled for compute capability 9.0 (__CUDA_ARCH__ >= 900)
 * or newer. On every older device target the operator body is empty. This
 * keeps kernels that will never run there out of the compiled binary and so
 * reduces its size.
 *
 * __CUDA_ARCH__ is not defined during the host compilation pass. The arch
 * check therefore has to sit inside a device function body, where the macro
 * is visible, and this wrapper provides that spot.
 */
template <typename Kernel> struct enable_sm90_or_later : Kernel {
  template <typename... KernelArgs>
  CUTLASS_DEVICE void operator()(KernelArgs &&...kernel_args) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
    // Perfect-forward everything to the wrapped kernel's call operator.
    this->Kernel::operator()(std::forward<KernelArgs>(kernel_args)...);
#endif
  }
};

// Compile-time mapping from a paddle::DataType enumerator to two element
// types:
//   DataType - the element type handed to CUTLASS templates
//              (float / cutlass::half_t / cutlass::bfloat16_t),
//   data_t   - the corresponding Paddle-side element type
//              (float / paddle::float16 / paddle::bfloat16).
// Only FLOAT32, FLOAT16 and BFLOAT16 are specialized. Any other enumerator
// names the undefined primary template and fails to compile.
template <paddle::DataType D> class CutlassDtypeTraits;

// 32-bit float: the same type on both sides.
template <> class CutlassDtypeTraits<paddle::DataType::FLOAT32> {
public:
  using DataType = float;
  using data_t = float;
};

// IEEE half precision.
template <> class CutlassDtypeTraits<paddle::DataType::FLOAT16> {
public:
  using DataType = cutlass::half_t;
  using data_t = paddle::float16;
};

// bfloat16.
template <> class CutlassDtypeTraits<paddle::DataType::BFLOAT16> {
public:
  using DataType = cutlass::bfloat16_t;
  using data_t = paddle::bfloat16;
};

class CutlassGemmConfigMannager {
public:
  static CutlassGemmConfigMannager &getInstance() {
    static CutlassGemmConfigMannager instance;
    return instance;
  }

  CutlassGemmConfigMannager(const CutlassGemmConfigMannager &) = delete;
  CutlassGemmConfigMannager &
  operator=(const CutlassGemmConfigMannager &) = delete;

  void up_date_configs(const nlohmann::json &j) {
    std::lock_guard<std::mutex> lock(mutex_);
    for (auto it = j.begin(); it != j.end(); ++it) {
      json_[it.key()] = it.value();
    }
  }

  nlohmann::json *get_gemm_best_configs(const std::string &config_file_path) {
    if (!load_initialized_) {
      std::ifstream file(config_file_path);
      if (!file.good()) {
        throw std::runtime_error(
            "cutlass gemm_best_config can not be found, please set "
            "gemm_best_config'path as "
            "FLAGS_use_cutlass_device_best_config_path, or unset "
            "FLAGS_use_cutlass_device_best_config_path to tune "
            "gemm_best_config");
      }
      json_ = readJsonFromFile(config_file_path);
      load_initialized_ = true;
      save_initialized_ = false;
    }
    return &json_;
  }

private:
  void save_gemm_best_configs_(const std::string &config_file_path) {
    std::ifstream file(config_file_path);
    if (!file.good()) {
      std::ofstream new_file(config_file_path);
      new_file << json_.dump(4);
      new_file.close();
    } else {
      nlohmann::json old_json = readJsonFromFile(config_file_path);
      for (auto it = json_.begin(); it != json_.end(); ++it) {
        old_json[it.key()] = it.value();
      }
      json_ = old_json;
      std::ofstream new_file(config_file_path, std::ios::out | std::ios::trunc);
      new_file << json_.dump(4);
      new_file.close();
      file.close();
    }
    return;
  }

  CutlassGemmConfigMannager()
      : json_(nullptr), load_initialized_(false), save_initialized_(true) {}
  ~CutlassGemmConfigMannager() {
    std::lock_guard<std::mutex> lock(mutex_);
    if (save_initialized_) {
      std::string config_file_path = "fp8_fuse_gemm_config.json";
      save_gemm_best_configs_(config_file_path);
    }
    save_initialized_ = true;
    load_initialized_ = false;
    json_.clear();
  }
  mutable std::mutex mutex_;
  nlohmann::json json_;
  bool load_initialized_;
  bool save_initialized_;
};
